# Global knitr/plotting setup for the accuracy analysis notebook.
knitr::opts_chunk$set(fig.align = "center")
library(rstanarm)
library(tidyverse)
library(tidybayes)
library(modelr)
library(ggplot2)
library(magrittr)
library(emmeans)
library(bayesplot)
library(brms)
library(gganimate)
theme_set(theme_light())

# Load the per-trial accuracy data. oracle, search, and dataset are
# categorical predictors, so coerce them to factors in one step.
# (Use `<-` for assignment per the tidyverse style guide.)
accuracy_data <- read.csv("processed_accuracy_split.csv")
accuracy_data <- accuracy_data %>%
  mutate(across(c(oracle, search, dataset), as.factor))

# Containers filled in per task (Find Extremum / Retrieve Value) below.
models <- list()
draw_data <- list()
search_differences <- list()
oracle_differences <- list()

# Shared RNG seed for model fitting and posterior draws (reproducibility).
seed <- 12
In our experiment, we used a visualization recommendation algorithm (composed of one search algorithm and one oracle algorithm) to generate visualizations for the user on one of two datasets. We then measured the user’s accuracy on two tasks: Find Extremum and Retrieve Value.
Given a search algorithm (bfs or dfs), an oracle (compassql or dziban), and a dataset (birdstrikes or movies), we would like to predict a user’s chance of answering the Find Extremum and Retrieve Value tasks correctly. In addition, we would like to know whether the choice of search algorithm and oracle has any meaningful impact on a user’s accuracy for these two tasks.
## Find Extremum: Building a Model for Accuracy Analysis
# Restrict to the Find Extremum task and fit a Bayesian logistic regression
# predicting accuracy from oracle, search (with interaction), and dataset.
data_find_extremum <- accuracy_data %>% filter(task == "1. Find Extremum")

models$find_extremum <- brm(
  formula = accuracy ~ oracle * search + dataset,
  data    = data_find_extremum,
  family  = bernoulli(link = "logit"),
  prior   = c(prior(normal(1, .05), class = Intercept)),
  warmup  = 500,
  iter    = 3000,
  chains  = 2,
  cores   = 2,
  seed    = seed,
  file    = "acc_find_extremum"  # cached fit; delete the file to refit
)
In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.
# Convergence diagnostics: want Rhat near 1.0 and Bulk_ESS in the thousands.
summary(models$find_extremum)
## Family: bernoulli
## Links: mu = logit
## Formula: accuracy ~ oracle * search + dataset
## Data: data_find_extremum (Number of observations: 59)
## Samples: 2 chains, each with iter = 3000; warmup = 500; thin = 1;
## total post-warmup samples = 5000
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## Intercept 0.60 0.61 -0.58 1.82 1.00 3020
## oracledziban 0.81 0.87 -0.85 2.51 1.00 2751
## searchdfs 0.40 0.85 -1.24 2.13 1.00 2785
## datasetmovies 0.20 0.61 -0.98 1.35 1.00 4338
## oracledziban:searchdfs -1.17 1.21 -3.48 1.18 1.00 2307
## Tail_ESS
## Intercept 3497
## oracledziban 3517
## searchdfs 2893
## datasetmovies 3393
## oracledziban:searchdfs 2822
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
Trace plots help us check whether there is evidence of non-convergence for the model.
# Trace and density plots for each parameter (visual convergence check).
plot(models$find_extremum)
In our pairs plots, we want to make sure we don’t have highly correlated parameters (highly correlated parameters mean that our model has difficulty differentiating the effects of those parameters).
# Pairwise posterior scatterplots; strong correlations suggest
# poorly-identified parameters.
pairs(models$find_extremum)
A confusion matrix can be used to check our correct classification rate (a useful measure to see how well our model fits our data).
# Threshold the posterior mean predicted probability at 0.5 to get a
# hard 0/1 classification, then cross-tabulate against observed accuracy.
pred <- predict(models$find_extremum, type = "response")
pred <- as.numeric(pred[, 1] > 0.5)
confusion_matrix <- table(pred, pull(data_find_extremum, accuracy))
confusion_matrix
##
## pred 0 1
## 1 5 54
Visualization of parameter effects via draws from our model posterior. The thicker, shorter line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.
# Posterior draws of the fitted success probability for every
# oracle/search/dataset cell.
# NOTE(review): add_fitted_draws() is deprecated in newer tidybayes in
# favour of add_epred_draws(); kept to match the saved output — confirm
# the installed tidybayes version before changing.
draw_data$find_extremum <- data_find_extremum %>%
  add_fitted_draws(models$find_extremum, seed = seed, re_formula = NA) %>%
  group_by(search, oracle, dataset, .draw)
draw_data$find_extremum$task <- "1. Find Extremum"
draw_data$find_extremum$condition <- paste(draw_data$find_extremum$oracle,
                                           draw_data$find_extremum$search,
                                           sep = "_")

# Fix: alpha was mapped inside aes(), which creates a spurious alpha
# legend and routes the value through a scale; set it as a layer
# constant instead.
find_extremum_plot <- draw_data$find_extremum %>% ggplot(aes(
  x = .value,
  y = condition,
  fill = dataset
)) + stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  labs(x = "Predicted Accuracy (p_correct)", y = "Oracle/Search Combination")
find_extremum_plot
Since the credible intervals on our plot overlap, we can use mean_qi to get the numeric boundaries for the different intervals.
# Numeric 95% and 50% quantile intervals of the posterior mean accuracy
# for each oracle/search/dataset condition.
fit_info <- draw_data$find_extremum %>%
  group_by(search, oracle, dataset) %>%
  mean_qi(.value, .width = c(.95, .5))
fit_info
## # A tibble: 16 x 9
## # Groups: search, oracle [4]
## search oracle dataset .value .lower .upper .width .point .interval
## <fct> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs compassql birdstrikes 0.636 0.359 0.861 0.95 mean qi
## 2 bfs compassql movies 0.678 0.423 0.882 0.95 mean qi
## 3 bfs dziban birdstrikes 0.787 0.552 0.938 0.95 mean qi
## 4 bfs dziban movies 0.817 0.605 0.949 0.95 mean qi
## 5 dfs compassql birdstrikes 0.717 0.454 0.902 0.95 mean qi
## 6 dfs compassql movies 0.753 0.519 0.922 0.95 mean qi
## 7 dfs dziban birdstrikes 0.645 0.380 0.865 0.95 mean qi
## 8 dfs dziban movies 0.688 0.441 0.879 0.95 mean qi
## 9 bfs compassql birdstrikes 0.636 0.547 0.732 0.5 mean qi
## 10 bfs compassql movies 0.678 0.600 0.768 0.5 mean qi
## 11 bfs dziban birdstrikes 0.787 0.730 0.862 0.5 mean qi
## 12 bfs dziban movies 0.817 0.766 0.882 0.5 mean qi
## 13 dfs compassql birdstrikes 0.717 0.644 0.802 0.5 mean qi
## 14 dfs compassql movies 0.753 0.685 0.832 0.5 mean qi
## 15 dfs dziban birdstrikes 0.645 0.557 0.739 0.5 mean qi
## 16 dfs dziban movies 0.688 0.613 0.773 0.5 mean qi
## Saving 7 x 5 in image
Next, we want to see if there is any significant difference between the two search algorithms (bfs and dfs) and the two oracles (dziban and compassql).
Differences in search algorithms:
# Posterior-predictive draws for each observation, then per-draw mean
# accuracy by search algorithm and dataset; compare_levels() yields the
# dfs - bfs difference distribution.
find_extremum_predictive_data <- data_find_extremum %>%
  add_predicted_draws(models$find_extremum, seed = seed, re_formula = NA) %>%
  group_by(search, oracle, dataset, .draw)
search_differences$find_extremum <- find_extremum_predictive_data %>%
  group_by(search, dataset, .draw) %>%
  # weighted.mean() with no weights is just mean(); say so explicitly.
  # .groups = "drop_last" matches the default regrouping but silences
  # the summarise() message.
  summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
  compare_levels(accuracy, by = search) %>%
  rename(difference_in_accuracy = accuracy)
## `summarise()` regrouping output by 'search', 'dataset' (override with `.groups` argument)
search_differences$find_extremum$metric <- "1. Find Extremum"
# Distribution of the dfs - bfs accuracy difference, faceted by dataset.
# Fixes: extract the axis-label value as a scalar (`$search[1]`) rather
# than a 1x1 tibble (`[1, 'search']`), and set alpha as a layer constant
# instead of mapping it inside aes() (which adds a spurious legend).
search_differences$find_extremum %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0("Expected Difference in Accuracy (",
              search_differences$find_extremum$search[1], ")")) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# 95% and 50% intervals for the search difference; check whether they
# contain zero.
search_differences$find_extremum %>% mean_qi(difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups: search [1]
## search dataset difference_in_accur… .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dfs - bfs birdstri… -0.0300 -0.429 0.357 0.95 mean qi
## 2 dfs - bfs movies -0.0328 -0.375 0.342 0.95 mean qi
## 3 dfs - bfs birdstri… -0.0300 -0.143 0.143 0.5 mean qi
## 4 dfs - bfs movies -0.0328 -0.171 0.0833 0.5 mean qi
Differences in oracle:
# Per-draw mean accuracy by oracle and dataset; compare_levels() yields
# the dziban - compassql difference distribution.
oracle_differences$find_extremum <- find_extremum_predictive_data %>%
  group_by(oracle, dataset, .draw) %>%
  # weighted.mean() with no weights is just mean(); .groups matches the
  # default regrouping while silencing the summarise() message.
  summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
  compare_levels(accuracy, by = oracle) %>%
  rename(difference_in_accuracy = accuracy)
## `summarise()` regrouping output by 'oracle', 'dataset' (override with `.groups` argument)
oracle_differences$find_extremum$metric <- "1. Find Extremum"
# Distribution of the dziban - compassql accuracy difference by dataset.
# Fixes: scalar label extraction (`$oracle[1]` instead of a 1x1 tibble),
# and alpha as a layer constant rather than an aes() mapping.
oracle_differences$find_extremum %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0("Expected Difference in Accuracy (",
              oracle_differences$find_extremum$oracle[1], ")")) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# 95% and 50% intervals for the oracle difference; check whether they
# contain zero.
oracle_differences$find_extremum %>% mean_qi(difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups: oracle [1]
## oracle dataset difference_in_acc… .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban - c… birdstr… 0.0412 -0.357 0.429 0.95 mean qi
## 2 dziban - c… movies 0.0326 -0.312 0.404 0.95 mean qi
## 3 dziban - c… birdstr… 0.0412 -0.0714 0.143 0.5 mean qi
## 4 dziban - c… movies 0.0326 -0.108 0.15 0.5 mean qi
# Restrict to the Retrieve Value task and fit the same Bayesian logistic
# regression specification used for Find Extremum.
data_retrieve_value <- accuracy_data %>% filter(task == "2. Retrieve Value")

models$retrieve_value <- brm(
  formula = accuracy ~ oracle * search + dataset,
  data    = data_retrieve_value,
  family  = bernoulli(link = "logit"),
  prior   = c(prior(normal(1, .05), class = Intercept)),
  warmup  = 500,
  iter    = 3000,
  chains  = 2,
  cores   = 2,
  seed    = seed,
  file    = "acc_retrieve_value"  # cached fit; delete the file to refit
)
In the summary table, we want to see Rhat values close to 1.0 and Bulk_ESS in the thousands.
# Convergence diagnostics: want Rhat near 1.0 and Bulk_ESS in the thousands.
summary(models$retrieve_value)
## Family: bernoulli
## Links: mu = logit
## Formula: accuracy ~ oracle * search + dataset
## Data: data_retrieve_value (Number of observations: 59)
## Samples: 2 chains, each with iter = 3000; warmup = 500; thin = 1;
## total post-warmup samples = 5000
##
## Population-Level Effects:
## Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## Intercept 0.98 0.60 -0.15 2.19 1.00 3087
## oracledziban 0.40 0.84 -1.23 2.04 1.00 2815
## searchdfs 0.83 0.89 -0.89 2.58 1.00 2933
## datasetmovies -0.51 0.62 -1.76 0.71 1.00 3744
## oracledziban:searchdfs -1.18 1.21 -3.53 1.21 1.00 2424
## Tail_ESS
## Intercept 3083
## oracledziban 2994
## searchdfs 3226
## datasetmovies 3214
## oracledziban:searchdfs 3159
##
## Samples were drawn using sampling(NUTS). For each parameter, Bulk_ESS
## and Tail_ESS are effective sample size measures, and Rhat is the potential
## scale reduction factor on split chains (at convergence, Rhat = 1).
Trace plots help us check whether there is evidence of non-convergence for the model.
# Trace and density plots for each parameter (visual convergence check).
plot(models$retrieve_value)
In our pairs plots, we want to make sure we don’t have highly correlated parameters (highly correlated parameters mean that our model has difficulty differentiating the effects of those parameters).
# Pairwise posterior scatterplots; strong correlations suggest
# poorly-identified parameters.
pairs(models$retrieve_value)
A confusion matrix can be used to check our correct classification rate (a useful measure to see how well our model fits our data).
# Hard-classify each observation at a 0.5 probability threshold and
# tabulate predictions against the observed outcomes.
pred <- predict(models$retrieve_value, type = "response")
pred <- as.numeric(pred[, 1] > 0.5)
confusion_matrix <- table(pred, pull(data_retrieve_value, accuracy))
confusion_matrix
##
## pred 0 1
## 1 5 54
Visualization of parameter effects via draws from our model posterior. The thicker, shorter line represents the 50% credible interval, while the thinner, longer line represents the 95% credible interval.
# Posterior draws of the fitted success probability for every
# oracle/search/dataset cell.
# NOTE(review): add_fitted_draws() is deprecated in newer tidybayes in
# favour of add_epred_draws(); kept to match the saved output — confirm
# the installed tidybayes version before changing.
draw_data$retrieve_value <- data_retrieve_value %>%
  add_fitted_draws(models$retrieve_value, seed = seed, re_formula = NA) %>%
  group_by(search, oracle, dataset, .draw)
draw_data$retrieve_value$task <- "2. Retrieve Value"
draw_data$retrieve_value$condition <- paste(draw_data$retrieve_value$oracle,
                                            draw_data$retrieve_value$search,
                                            sep = "_")

# Fix: alpha was mapped inside aes(), which creates a spurious alpha
# legend and routes the value through a scale; set it as a layer
# constant instead.
retrieve_value_plot <- draw_data$retrieve_value %>% ggplot(aes(
  x = .value,
  y = condition,
  fill = dataset
)) + stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  labs(x = "Predicted Accuracy (p_correct)", y = "Oracle/Search Combination")
retrieve_value_plot
Since the credible intervals on our plot overlap, we can use mean_qi to get the numeric boundaries for the different intervals.
# Numeric 95% and 50% quantile intervals of the posterior mean accuracy
# for each oracle/search/dataset condition.
fit_info <- draw_data$retrieve_value %>%
  group_by(search, oracle, dataset) %>%
  mean_qi(.value, .width = c(.95, .5))
fit_info
## # A tibble: 16 x 9
## # Groups: search, oracle [4]
## search oracle dataset .value .lower .upper .width .point .interval
## <fct> <fct> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 bfs compassql birdstrikes 0.712 0.462 0.899 0.95 mean qi
## 2 bfs compassql movies 0.605 0.325 0.844 0.95 mean qi
## 3 bfs dziban birdstrikes 0.780 0.547 0.935 0.95 mean qi
## 4 bfs dziban movies 0.689 0.426 0.887 0.95 mean qi
## 5 dfs compassql birdstrikes 0.841 0.632 0.958 0.95 mean qi
## 6 dfs compassql movies 0.766 0.514 0.929 0.95 mean qi
## 7 dfs dziban birdstrikes 0.720 0.461 0.905 0.95 mean qi
## 8 dfs dziban movies 0.615 0.340 0.843 0.95 mean qi
## 9 bfs compassql birdstrikes 0.712 0.636 0.797 0.5 mean qi
## 10 bfs compassql movies 0.605 0.516 0.704 0.5 mean qi
## 11 bfs dziban birdstrikes 0.780 0.721 0.857 0.5 mean qi
## 12 bfs dziban movies 0.689 0.612 0.778 0.5 mean qi
## 13 dfs compassql birdstrikes 0.841 0.796 0.903 0.5 mean qi
## 14 dfs compassql movies 0.766 0.702 0.847 0.5 mean qi
## 15 dfs dziban birdstrikes 0.720 0.648 0.804 0.5 mean qi
## 16 dfs dziban movies 0.615 0.527 0.711 0.5 mean qi
## Saving 7 x 5 in image
Next, we want to see if there is any significant difference between the two search algorithms (bfs and dfs) and the two oracles (dziban and compassql).
Differences in search algorithms:
# Posterior-predictive draws, then per-draw mean accuracy by search
# algorithm and dataset; compare_levels() yields the dfs - bfs
# difference distribution.
retrieve_value_predictive_data <- data_retrieve_value %>%
  add_predicted_draws(models$retrieve_value, seed = seed, re_formula = NA) %>%
  group_by(search, oracle, dataset, .draw)
search_differences$retrieve_value <- retrieve_value_predictive_data %>%
  group_by(search, dataset, .draw) %>%
  # weighted.mean() with no weights is just mean(); .groups matches the
  # default regrouping while silencing the summarise() message.
  summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
  compare_levels(accuracy, by = search) %>%
  rename(difference_in_accuracy = accuracy)
## `summarise()` regrouping output by 'search', 'dataset' (override with `.groups` argument)
search_differences$retrieve_value$metric <- "2. Retrieve Value"
# Distribution of the dfs - bfs accuracy difference, faceted by dataset.
# Fixes: scalar label extraction (`$search[1]` instead of a 1x1 tibble),
# and alpha as a layer constant rather than an aes() mapping.
search_differences$retrieve_value %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0("Expected Difference in Accuracy (",
              search_differences$retrieve_value$search[1], ")")) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# 95% and 50% intervals for the search difference; check whether they
# contain zero.
search_differences$retrieve_value %>% mean_qi(difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups: search [1]
## search dataset difference_in_accu… .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dfs - b… birdstrik… 0.0345 -0.287 0.429 0.95 mean qi
## 2 dfs - b… movies 0.0387 -0.367 0.467 0.95 mean qi
## 3 dfs - b… birdstrik… 0.0345 -0.0714 0.143 0.5 mean qi
## 4 dfs - b… movies 0.0387 -0.104 0.158 0.5 mean qi
Differences in oracle:
# Per-draw mean accuracy by oracle and dataset; compare_levels() yields
# the dziban - compassql difference distribution.
oracle_differences$retrieve_value <- retrieve_value_predictive_data %>%
  group_by(oracle, dataset, .draw) %>%
  # weighted.mean() with no weights is just mean(); .groups matches the
  # default regrouping while silencing the summarise() message.
  summarize(accuracy = mean(.prediction), .groups = "drop_last") %>%
  compare_levels(accuracy, by = oracle) %>%
  rename(difference_in_accuracy = accuracy)
## `summarise()` regrouping output by 'oracle', 'dataset' (override with `.groups` argument)
oracle_differences$retrieve_value$metric <- "2. Retrieve Value"
# Distribution of the dziban - compassql accuracy difference by dataset.
# Fixes: scalar label extraction (`$oracle[1]` instead of a 1x1 tibble),
# and alpha as a layer constant rather than an aes() mapping.
oracle_differences$retrieve_value %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0("Expected Difference in Accuracy (",
              oracle_differences$retrieve_value$oracle[1], ")")) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)
We can double-check the boundaries of the credible intervals to be sure whether or not the interval contains zero.
# 95% and 50% intervals for the oracle difference; check whether they
# contain zero.
oracle_differences$retrieve_value %>% mean_qi(difference_in_accuracy, .width = c(.95, .5))
## # A tibble: 4 x 8
## # Groups: oracle [1]
## oracle dataset difference_in_acc… .lower .upper .width .point .interval
## <chr> <fct> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban - c… birdstri… -0.0240 -0.357 0.357 0.95 mean qi
## 2 dziban - c… movies -0.0444 -0.429 0.346 0.95 mean qi
## 3 dziban - c… birdstri… -0.0240 -0.143 0.0714 0.5 mean qi
## 4 dziban - c… movies -0.0444 -0.175 0.0875 0.5 mean qi
Putting all of the plots for search algorithm differences on the same plot:
# Combine both tasks' search-difference draws and plot them together.
combined_search_differences <- rbind(search_differences$find_extremum,
                                     search_differences$retrieve_value)
# Fixes: extract the axis-label value as a scalar (`$search[1]`) rather
# than a 1x1 tibble, and set alpha as a layer constant instead of an
# aes() mapping (which adds a spurious legend).
search_differences_plot <- combined_search_differences %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0("Expected Difference in Accuracy (",
              combined_search_differences$search[1], ")")) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)
search_differences_plot
# Interval summaries per task/dataset for the combined search differences.
search_intervals <- combined_search_differences %>%
  group_by(search, dataset, metric) %>%
  mean_qi(difference_in_accuracy, .width = c(.95, .5))
search_intervals
## # A tibble: 8 x 9
## # Groups: search, dataset [2]
## search dataset metric difference_in_a… .lower .upper .width .point .interval
## <chr> <fct> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dfs - … birdst… 1. Fi… -0.0300 -0.429 0.357 0.95 mean qi
## 2 dfs - … birdst… 2. Re… 0.0345 -0.287 0.429 0.95 mean qi
## 3 dfs - … movies 1. Fi… -0.0328 -0.375 0.342 0.95 mean qi
## 4 dfs - … movies 2. Re… 0.0387 -0.367 0.467 0.95 mean qi
## 5 dfs - … birdst… 1. Fi… -0.0300 -0.143 0.143 0.5 mean qi
## 6 dfs - … birdst… 2. Re… 0.0345 -0.0714 0.143 0.5 mean qi
## 7 dfs - … movies 1. Fi… -0.0328 -0.171 0.0833 0.5 mean qi
## 8 dfs - … movies 2. Re… 0.0387 -0.104 0.158 0.5 mean qi
Putting all of the plots for oracle differences on the same plot:
# Combine both tasks' oracle-difference draws and plot them together.
combined_oracle_differences <- rbind(oracle_differences$find_extremum,
                                     oracle_differences$retrieve_value)
# Fixes: extract the axis-label value as a scalar (`$oracle[1]`) rather
# than a 1x1 tibble, and set alpha as a layer constant instead of an
# aes() mapping (which adds a spurious legend).
oracle_differences_plot <- combined_oracle_differences %>%
  ggplot(aes(x = difference_in_accuracy, y = metric, fill = dataset)) +
  xlab(paste0("Expected Difference in Accuracy (",
              combined_oracle_differences$oracle[1], ")")) +
  ylab("Task") +
  stat_halfeye(.width = c(.95, .5), alpha = 0.5) +
  geom_vline(xintercept = 0, linetype = "longdash") +
  theme_minimal() +
  facet_grid(. ~ dataset)
oracle_differences_plot
# Interval summaries per task/dataset for the combined oracle differences.
oracle_intervals <- combined_oracle_differences %>%
  group_by(oracle, dataset, metric) %>%
  mean_qi(difference_in_accuracy, .width = c(.95, .5))
oracle_intervals
## # A tibble: 8 x 9
## # Groups: oracle, dataset [2]
## oracle dataset metric difference_in_a… .lower .upper .width .point .interval
## <chr> <fct> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
## 1 dziban… birdst… 1. Fi… 0.0412 -0.357 0.429 0.95 mean qi
## 2 dziban… birdst… 2. Re… -0.0240 -0.357 0.357 0.95 mean qi
## 3 dziban… movies 1. Fi… 0.0326 -0.312 0.404 0.95 mean qi
## 4 dziban… movies 2. Re… -0.0444 -0.429 0.346 0.95 mean qi
## 5 dziban… birdst… 1. Fi… 0.0412 -0.0714 0.143 0.5 mean qi
## 6 dziban… birdst… 2. Re… -0.0240 -0.143 0.0714 0.5 mean qi
## 7 dziban… movies 1. Fi… 0.0326 -0.108 0.15 0.5 mean qi
## 8 dziban… movies 2. Re… -0.0444 -0.175 0.0875 0.5 mean qi